import torch.nn as nn


class DCDecoder(nn.Module):
    """Linear projection from encoder hidden vectors to DC label logits.

    A single bias-free ``nn.Linear`` layer maps vectors of width
    ``config.bert_hidden_size`` to ``vocab.DC_size`` scores.

    NOTE(review): the meanings of "DC" and "EDU" are inferred from names only
    (presumably discourse-related labels / elementary discourse units) --
    confirm against the rest of the project.
    """

    def __init__(self, vocab, config):
        """Build the projection layer.

        Args:
            vocab: Object exposing ``DC_size``, used as the number of outputs.
            config: Object exposing ``bert_hidden_size``, used as the input
                width; the object itself is stored on ``self.config``.
        """
        super().__init__()
        self.config = config
        hidden_width = config.bert_hidden_size
        num_labels = vocab.DC_size
        # No bias term: the projection is the weight matrix alone.
        self.output = nn.Linear(hidden_width, num_labels, bias=False)

    def forward(self, edu_represents):
        """Return raw (unnormalized) scores for ``edu_represents``.

        ``nn.Linear`` acts on the last dimension, so the input's trailing size
        must equal ``config.bert_hidden_size``. The result keeps the leading
        dimensions and has trailing size ``vocab.DC_size``. No activation or
        softmax is applied here.
        """
        return self.output(edu_represents)